# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T2TModel Base Class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import functools
import math
import os
import time
import six

from tensor2tensor.data_generators import multi_problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators.problem import problem_hparams_to_features
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.layers.common_attention import mixed_precision_is_enabled
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import contrib
from tensor2tensor.utils import decoding
from tensor2tensor.utils import expert_utils as eu
from tensor2tensor.utils import hparams_lib
from tensor2tensor.utils import learning_rate
from tensor2tensor.utils import metrics
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import optimize
from tensor2tensor.utils import quantization
from tensor2tensor.utils import registry
from tensor2tensor.utils import scheduled_sampling

import tensorflow.compat.v1 as tf

from tensorflow.python.layers import base
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_inspect as inspect

_no_problem_err_str = (
    'The default implementation of %s requires that the '
    'model be used with a Problem. If using a Problem, augment the '
    'hparams object with trainer_lib.add_problem_hparams. If not, '
    'override %s.'
)
_no_problem_err = lambda method_name: _no_problem_err_str % (
    method_name,
    method_name,
)


def _flatten_dict(original_dict):
    """Flatten dict of dicts into a single dict with appropriate prefixes.

  Handles only 2 levels of nesting in the original dict.

  Args:
    original_dict: Dict which may contain one or more dicts.
  Returns:
    flat_dict: Dict without any nesting. Any dicts in the original dict have
      their keys as prefixes in the new dict.
  Raises:
    ValueError if the original dict has more than two levels of nesting.
  """
    flat_dict = {}
    for key, value in original_dict.items():
        if isinstance(value, dict):
            for name, tensor in value.items():
                if isinstance(tensor, dict):
                    raise ValueError(
                        'flatten_dict only handles 2 levels of nesting.'
                    )
                flat_key = '__' + key + '_' + name
                flat_dict[flat_key] = tensor
        else:
            flat_dict[key] = value

    return flat_dict
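
# Example (for illustration; plain Python values stand in for Tensors):
#   _flatten_dict({'inputs': 1, 'losses': {'training': 0.5}})
#   # => {'inputs': 1, '__losses_training': 0.5}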


def _unflatten_dict(flat_dict, prefixes):
    """Returns a dict of dicts if any prefixes match keys in the flat dict.

    The function handles the case where the prefix may not be a dict.

  Args:
    flat_dict: A dict without any nesting.
    prefixes: A list of strings which may have been dicts in the
      original structure.

  """
    original_dict = {}
    for key, value in flat_dict.items():
        prefix_found = False
        for prefix in prefixes:
            full_prefix = '__' + prefix + '_'
            if key.startswith(full_prefix):
                # Add a dict to the original dict with key=prefix
                if prefix not in original_dict:
                    original_dict[prefix] = {}
                original_dict[prefix][key[len(full_prefix) :]] = value
                prefix_found = True
                break
        if not prefix_found:
            # No key matched a prefix in the for loop.
            original_dict[key] = value

    return original_dict
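
# Example (round-trip with _flatten_dict; illustrative values):
#   _unflatten_dict({'inputs': 1, '__losses_training': 0.5}, ['losses'])
#   # => {'inputs': 1, 'losses': {'training': 0.5}}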


class T2TModel(base.Layer):
    """Abstract base class for models.

  `T2TModel` has three typical usages:

  1. Estimator: The method `make_estimator_model_fn` builds a `model_fn` for
     the tf.Estimator workflow of training, evaluation, and prediction.
     It performs the method `call`, which performs the core computation,
     followed by `estimator_spec_train`, `estimator_spec_eval`, or
     `estimator_spec_predict` depending on the tf.Estimator mode.
  2. Layer: The method `call` enables `T2TModel` to be used a callable by
     itself. It calls the following methods:

     * `bottom`, which transforms features according to `problem_hparams`' input
       and target `Modality`s;
     * `body`, which takes features and performs the core model computation to
        return output and any auxiliary loss terms;
     * `top`, which takes features and the body output, and transforms them
       according to `problem_hparams`' input and target `Modality`s to return
       the final logits;
     * `loss`, which takes the logits, forms any missing training loss, and sums
       all loss terms.
  3. Inference: The method `infer` enables `T2TModel` to make sequence
     predictions by itself.

  Subclasses generally only need to override `body`.
  """

    REGISTERED_NAME = None  # Updated on registration.

    def __init__(
        self,
        hparams,
        mode = tf.estimator.ModeKeys.TRAIN,
        problem_hparams = None,
        data_parallelism = None,
        decode_hparams = None,
        **kwargs
    ):
        """Creates a T2TModel.

    Args:
      hparams: HParams, model hyperparameters.
      mode: tf.estimator.ModeKeys, the execution mode.
      problem_hparams: HParams, hyperparameters for the
        Problem. If provided here or in hparams.problem_hparams, the model will
        automatically determine bottom, top, and loss methods. If not provided,
        calling the model will only invoke body.
      data_parallelism: a expert_utils.Parallelism object,
        specifies devices for data parallelism.
      decode_hparams: a hyperparameter object with decoding parameters.
        See decoding.decode_hparams.
      **kwargs: arguments to pass to base.Layer constructor.
    """
        # Determine name first: use registered name if possible, class name else.
        default_name = registry.default_name(type(self))
        name = self.REGISTERED_NAME or default_name
        super(T2TModel, self).__init__(
            trainable = mode == tf.estimator.ModeKeys.TRAIN,
            name = name,
            **kwargs
        )

        if not problem_hparams and hasattr(hparams, 'problem_hparams'):
            problem_hparams = hparams.problem_hparams
        self._problem_hparams = problem_hparams

        # Setup hparams
        hparams = hparams_lib.copy_hparams(hparams)
        if (
            self._problem_hparams
            and hparams.shared_embedding_and_softmax_weights
        ):
            # If vocabularies differ, unset shared_embedding_and_softmax_weights.
            input_vocab_size = self._problem_hparams.vocab_size.get('inputs')
            target_vocab_size = self._problem_hparams.vocab_size.get('targets')
            if input_vocab_size is not None and hasattr(
                hparams, 'vocab_divisor'
            ):
                input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
            if target_vocab_size is not None and hasattr(
                hparams, 'vocab_divisor'
            ):
                target_vocab_size += (
                    -target_vocab_size
                ) % hparams.vocab_divisor
            if (
                input_vocab_size is not None
                and target_vocab_size is not None
                and input_vocab_size != target_vocab_size
            ):
                log_info('Unsetting shared_embedding_and_softmax_weights.')
                hparams.shared_embedding_and_softmax_weights = 0

            if hparams.hidden_size:
                hidden_size = hparams.hidden_size
            else:
                hidden_size = 1024
            mlperf_log.transformer_print(
                key = mlperf_log.MODEL_HP_EMBEDDING_SHARED_WEIGHTS,
                value = {
                    'vocab_size': target_vocab_size,
                    'hidden_size': hidden_size,
                },
                hparams = hparams,
            )

        if self._problem_hparams:
            for feature_name, modality in six.iteritems(
                self._problem_hparams.modality
            ):
                # If prepend mode, set weights_fn to appropriately handle it.
                if modality in (
                    modalities.ModalityType.CTC_SYMBOL,
                    modalities.ModalityType.IDENTITY_SYMBOL,
                    modalities.ModalityType.SYMBOL,
                    modalities.ModalityType.SYMBOL_ONE_HOT,
                ):
                    if (
                        hparams.prepend_mode == 'prepend_inputs_full_attention'
                        or (
                            hparams.prepend_mode
                            == 'prepend_inputs_masked_attention'
                            and mode != tf.estimator.ModeKeys.TRAIN
                        )
                    ):
                        weights_fn = (
                            common_layers.weights_prepend_inputs_to_targets
                        )
                        hparams.weights_fn[feature_name] = weights_fn

        self._original_hparams = hparams
        self.set_mode(mode)

        self._decode_hparams = hparams_lib.copy_hparams(
            decode_hparams or decoding.decode_hparams()
        )
        self._data_parallelism = data_parallelism or eu.Parallelism([''])
        self._num_datashards = self._data_parallelism.n
        self._ps_devices = self._data_parallelism.ps_devices
        self._eager_var_store = create_eager_var_store()
        if not common_layers.is_xla_compiled():
            self.summarize_hparams()
        self._variable_scopes = {}
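
    # A hypothetical construction sketch (registered names for illustration):
    #   hparams = registry.hparams('transformer_base')
    #   model = registry.model('transformer')(
    #       hparams, mode=tf.estimator.ModeKeys.EVAL)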

    def _add_variable_scope(self, key, vs):
        if key not in self._variable_scopes:
            self._variable_scopes[key] = vs

    def summarize_hparams(self):
        def create_hparams_summary(hparams, name):
            hparams_strs = [
                tf.convert_to_tensor([k, str(v)])
                for k, v in hparams.values().items()
            ]
            tf.summary.text(name, tf.cast(tf.stack(hparams_strs), tf.string))

        create_hparams_summary(self._hparams, '%s_hparams' % self.name)
        if self._problem_hparams:
            create_hparams_summary(
                self._problem_hparams, '%s_problem_hparams' % self.name
            )

    # Replace the two methods below in order to add custom SessionRunHooks to
    # the training procedure.
    @staticmethod
    def train_hooks(hook_context):
        return []

    @staticmethod
    def eval_hooks(hook_context):
        return []

    @property
    def hparams(self):
        return self._hparams

    @property
    def problem_hparams(self):
        return self._problem_hparams

    @property
    def is_training(self):
        return self._hparams.mode == tf.estimator.ModeKeys.TRAIN

    @property
    def is_predicting(self):
        return self._hparams.mode == tf.estimator.ModeKeys.PREDICT

    @property
    def has_input(self):
        if self._problem_hparams:
            return 'inputs' in self._problem_hparams.modality
        else:
            return True

    @property
    def _custom_getter(self):
        if self.hparams.weight_dtype == 'bfloat16':
            if self.hparams.optimizer != 'Adafactor':
                raise NotImplementedError(
                    'weight_dtype=bfloat16 only implemented with Adafactor optimizer'
                )
            activation_dtype = tf.float32
            if self.hparams.activation_dtype == 'bfloat16':
                activation_dtype = tf.bfloat16
            return quantization.EighthPowerEncoding().custom_getter(
                activation_dtype = activation_dtype
            )
        elif self.hparams.activation_dtype == 'bfloat16':
            return quantization.bfloat16_activations_var_getter
        elif mixed_precision_is_enabled(hparams = self.hparams):
            return quantization.float16_activations_var_getter
        else:
            return None

    @property
    def _target_modality_is_real(self):
        """Whether the target modality is real-valued."""
        vocab_size = self._problem_hparams.vocab_size['targets']
        if vocab_size is not None and hasattr(self._hparams, 'vocab_divisor'):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        modality = self._problem_hparams.modality['targets']
        modality_name = self._hparams.name.get(
            'targets', modalities.get_name(modality)
        )(self._hparams, vocab_size)
        return modality_name.startswith('real')

    def call(self, inputs, **kwargs):
        del kwargs
        features = inputs
        set_custom_getter_compose(self._custom_getter)
        tf.get_variable_scope().set_initializer(
            optimize.get_variable_initializer(self.hparams)
        )
        with self._eager_var_store.as_default():
            self._fill_problem_hparams_features(features)
            summarize_features(features, num_shards = self._num_datashards)
            sharded_features = self._shard_features(features)
            sharded_logits, losses = self.model_fn_sharded(sharded_features)
            if isinstance(sharded_logits, dict):
                concat_logits = {}
                for k, v in six.iteritems(sharded_logits):
                    concat_logits[k] = tf.concat(v, 0)
                return concat_logits, losses
            else:
                return tf.concat(sharded_logits, 0), losses

    @staticmethod
    def has_symmetric_shards(model_name):
        # model_fn is sharded symmetrically unless the model overrides body_sharded
        # method to manually control the sharding.
        model_cls = registry.model(model_name)
        return not model_cls.use_body_sharded()

    @staticmethod
    def use_body_sharded():
        return False

    def body_sharded(self, sharded_features):
        raise NotImplementedError(
            'Models that wish to manually control sharding, '
            'e.g. MoE models, should override body_sharded '
            'and set use_body_sharded to True.'
        )

    def model_fn_sharded(self, sharded_features):
        """Estimator model_fn sharded along batch dimension.

    Args:
      sharded_features: {str: [Tensor]}. Features sharded along batch dimension.
        Each list is the same length (== number of shards).

    Returns:
      sharded_logits: [Tensor]. Logits for each shard of examples.
      losses: {str: 0-D Tensor}. Loss averaged across shards.
    """
        dp = self._data_parallelism

        # [{str: Tensor}]. Transpose of 'sharded_features'.
        datashard_to_features = self._to_features_per_datashard(
            sharded_features
        )
        if self.use_body_sharded():
            if self.hparams.scheduled_sampling_prob > 0.0:
                raise NotImplementedError(
                    'Scheduled sampling for non-sharded body only.'
                )

            # MoE models override body_sharded
            transformed_features = dp(self.bottom, datashard_to_features)
            body_out = self.body_sharded(
                self._to_single_features_dict(transformed_features)
            )
            body_out, losses = self._normalize_body_output(body_out)
            if 'training' in losses:
                log_info(
                    'Skipping T2TModel top and loss because training loss '
                    'returned from body'
                )
                sharded_logits = body_out
            else:
                if isinstance(body_out, dict):
                    sharded_logits = collections.OrderedDict()
                    sharded_losses = collections.OrderedDict()
                    for k, v in sorted(six.iteritems(body_out)):
                        sharded_logits[k] = dp(
                            self.top, v, datashard_to_features
                        )
                        sharded_losses[k] = dp(
                            self.loss, sharded_logits[k], datashard_to_features
                        )
                    training_loss_dict = average_sharded_losses(
                        [
                            ({'training': l} for l in loss)
                            for loss in sharded_losses.values()
                        ]
                    )
                    losses.update(training_loss_dict)
                else:
                    sharded_logits = dp(
                        self.top, body_out, datashard_to_features
                    )
                    sharded_losses = dp(
                        self.loss, sharded_logits, datashard_to_features
                    )
                    if isinstance(sharded_losses, tuple):
                        nums, dens = sharded_losses
                        sharded_losses = zip(nums, dens)
                    training_loss_dict = average_sharded_losses(
                        [{'training': loss} for loss in sharded_losses]
                    )
                    losses.update(training_loss_dict)
        else:
            sharded_logits, sharded_losses = dp(
                self.model_fn, datashard_to_features
            )
            sharded_logits, sharded_losses = dp(
                self.maybe_scheduled_sampling,
                datashard_to_features,
                sharded_logits,
                sharded_losses,
            )
            if isinstance(sharded_logits[0], dict):
                temp_dict = {k: [] for k, _ in six.iteritems(sharded_logits[0])}
                for k, _ in six.iteritems(sharded_logits[0]):
                    for l in sharded_logits:
                        temp_dict[k].append(l[k])
                sharded_logits = temp_dict
            losses = average_sharded_losses(sharded_losses)

        return sharded_logits, losses

    def model_fn(self, features):
        with tf.variable_scope(
            tf.get_variable_scope(), use_resource = True
        ) as vs:
            self._add_variable_scope('model_fn', vs)
            transformed_features = self.bottom(features)

            if self.hparams.activation_dtype == 'bfloat16':
                for k, v in sorted(six.iteritems(transformed_features)):
                    if v.dtype == tf.float32:
                        transformed_features[k] = tf.cast(v, tf.bfloat16)

            body_out = self.body(transformed_features)
            output, losses = self._normalize_body_output(body_out)

            if 'training' in losses:
                log_info(
                    'Skipping T2TModel top and loss because training loss '
                    'returned from body'
                )
                logits = output
            else:
                logits = self.top(output, features)
                losses['training'] = 0.0
                if (
                    self._hparams.mode != tf.estimator.ModeKeys.PREDICT
                    and self._hparams.mode != 'attack'
                ):
                    losses['training'] = self.loss(logits, features)

            return logits, losses

    def bottom(self, features):
        """Transforms features to feed into body.

    Args:
      features: dict of str to Tensor. Typically it is the preprocessed data
        batch after Problem's preprocess_example().

    Returns:
      transformed_features: dict of same key-value pairs as features. The value
        Tensors are newly transformed.
    """
        if not self._problem_hparams:
            log_warn('Without a Problem, T2TModel.bottom is a passthrough.')
            return features

        transformed_features = collections.OrderedDict()
        all_previous_modalities = []
        target_modality = _create_target_modality(
            self._problem_hparams.modality
        )

        # Transform features via its corresponding modality.
        for feature_name, modality in sorted(
            six.iteritems(self._problem_hparams.modality)
        ):
            if feature_name not in features:
                tf.logging.warning(
                    'Missing feature %s - ignoring.' % feature_name
                )
                continue
            vocab_size = self._problem_hparams.vocab_size[feature_name]
            if vocab_size is not None and hasattr(
                self._hparams, 'vocab_divisor'
            ):
                vocab_size += (-vocab_size) % self._hparams.vocab_divisor
            modality_name = self._hparams.name.get(
                feature_name, modalities.get_name(modality)
            )(self._hparams, vocab_size)
            # Use if-else clauses to preserve behavior of previous changes: namely,
            # the variable scope name for the targets feature if there is only one
            # target modality; and to reuse variable scopes for only input modalities.
            if feature_name in target_modality:
                if len(target_modality) > 1:
                    variable_scope_name = '%s/%s' % (
                        modality_name,
                        feature_name,
                    )
                else:
                    variable_scope_name = modality_name
                bottom = self._hparams.bottom.get(
                    feature_name, modalities.get_targets_bottom(modality)
                )
                # TODO(aidangomez): share variables?
                with tf.variable_scope(variable_scope_name) as vs:
                    self._add_variable_scope(variable_scope_name, vs)
                    log_info(
                        "Transforming feature '%s' with %s.targets_bottom",
                        feature_name,
                        modality_name,
                    )
                    transformed_features[feature_name] = bottom(
                        features[feature_name], self._hparams, vocab_size
                    )
            else:
                bottom = self._hparams.bottom.get(
                    feature_name, modalities.get_bottom(modality)
                )
                do_reuse = modality_name in all_previous_modalities
                with tf.variable_scope(modality_name, reuse = do_reuse) as vs:
                    self._add_variable_scope(modality_name, vs)
                    log_info(
                        "Transforming feature '%s' with %s.bottom",
                        feature_name,
                        modality_name,
                    )
                    transformed_features[feature_name] = bottom(
                        features[feature_name], self._hparams, vocab_size
                    )
                all_previous_modalities.append(modality_name)

        for key in features:
            if key not in transformed_features:
                # For features without a modality, we pass them along as is
                transformed_features[key] = features[key]
            else:
                # Other features get passed along with the "raw" suffix
                transformed_features[key + '_raw'] = features[key]

        return transformed_features
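
    # Note: for a symbol modality, `bottom` typically embeds integer ids of
    # shape [batch, length, 1, 1] into floats of shape
    # [batch, length, 1, hidden_size]; exact shapes depend on the modality.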

    def body(self, features):
        """Computes the targets' pre-logit activations given transformed inputs.

    Most `T2TModel` subclasses will override this method.

    Args:
      features: dict of str to Tensor, where each Tensor has shape [batch_size,
        ..., hidden_size]. It typically contains keys `inputs` and `targets`.

    Returns:
      output: Tensor of pre-logit activations with shape [batch_size, ...,
              hidden_size].
      losses: Either single loss as a scalar, a list, a Tensor (to be averaged),
              or a dictionary of losses. If losses is a dictionary with the key
              "training", losses["training"] is considered the final training
              loss and output is considered logits; self.top and self.loss will
              be skipped.
    """
        raise NotImplementedError('Abstract Method')

    def _top_single(self, body_output, feature_name, features):
        if not self._problem_hparams:
            log_warn('Without a Problem, T2TModel.top is a passthrough.')
            return body_output

        modality = self._problem_hparams.modality[feature_name]
        vocab_size = self._problem_hparams.vocab_size[feature_name]
        if vocab_size is not None and hasattr(self._hparams, 'vocab_divisor'):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        name = self._hparams.name.get(
            feature_name, modalities.get_name(modality)
        )(self._hparams, vocab_size)
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE) as tm_vs:
            self._add_variable_scope(tm_vs.name, tm_vs)
            log_info('Transforming body output with %s.top', name)
            top = self._hparams.top.get(
                feature_name, modalities.get_top(modality)
            )
            top_is_pointwise = getattr(top, 'pointwise', False)
            last_only = (
                top_is_pointwise
                and self.hparams.mode == tf.estimator.ModeKeys.PREDICT
                and not self.hparams.force_full_predict
            )
            if not last_only:
                logits = top(
                    body_output,
                    features.get('targets'),
                    self._hparams,
                    vocab_size,
                )
            else:
                # Take body outputs for the last position only, and targets too.
                if 'decode_loop_step' not in features:
                    last_position_body_output = tf.expand_dims(
                        body_output[:, -1, :, :], axis = [1]
                    )
                    last_position_targets = tf.expand_dims(
                        features['targets'][:, -1, :, :], axis = [1]
                    )
                else:
                    body_output_shape = body_output.shape.as_list()
                    last_position_body_output = tf.slice(
                        body_output,
                        [0, features['decode_loop_step'][0], 0, 0],
                        [
                            body_output_shape[0],
                            1,
                            body_output_shape[2],
                            body_output_shape[3],
                        ],
                    )
                    target_shape = features['targets'].shape.as_list()
                    last_position_targets = tf.slice(
                        features['targets'],
                        [0, features['decode_loop_step'][0], 0, 0],
                        [target_shape[0], 1, target_shape[2], target_shape[3]],
                    )
                logits = top(
                    last_position_body_output,
                    last_position_targets,
                    self._hparams,
                    vocab_size,
                )
        return logits

    def top(self, body_output, features):
        """Computes logits given body output and features.

    Args:
      body_output: dict of str to Tensor, comprising one key-value pair for each
        target. Each value denotes the target's pre-logit activations.
        Alternatively, it may be a single Tensor denoting the pre-logits for
        that target.
      features: dict of str to Tensor. Typically it is the preprocessed data
        batch after Problem's preprocess_example().

    Returns:
      logits: dict of str to Tensor, denoting each logits for each target; or
        a single Tensor denoting the logits for that target.
        When targets are generated at training time:
          logits == {
            "self_generated_targets": <generated targets tensor>
            "logits": <original logits Tensor or dict>
          }
    """
        if isinstance(body_output, dict):
            logits = {}
            for k, v in six.iteritems(body_output):
                # TODO(aidangomez): share variables here?
                with tf.variable_scope(k) as top_vs:
                    self._add_variable_scope('top_%s' % k, top_vs)
                    logits[k] = self._top_single(v, k, features)
            return logits
        else:
            return self._top_single(body_output, 'targets', features)

    def _loss_single(self, logits, feature_name, feature, weights = None):
        # The current bfloat16 version still uses float32 for most parts of backward
        # propagation to keep model quality, so cast back before computing the loss
        # value.
        if not self._problem_hparams:
            log_warn(_no_problem_err('loss'))
            return (
                tf.constant(0.0, dtype = tf.float32),
                tf.constant(1.0, dtype = tf.float32),
            )

        # Calculate loss contribution.
        modality = self._problem_hparams.modality[feature_name]
        vocab_size = self._problem_hparams.vocab_size[feature_name]
        if vocab_size is not None and hasattr(self._hparams, 'vocab_divisor'):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        loss = self._hparams.loss.get(
            feature_name, modalities.get_loss(modality)
        )
        targets_weights_fn = self._hparams.weights_fn.get(
            'targets', modalities.get_weights_fn(modality)
        )
        if weights is None:
            loss_num, loss_den = loss(
                logits,
                feature,
                self._hparams,
                vocab_size,
                weights_fn = targets_weights_fn,
            )
        else:

            def weights_fn(labels):
                """Per-token weights for loss."""
                # Use target_weights_fn() given by modality as well as explicitly given
                # weights.
                modality_weights = targets_weights_fn(labels)

                # Broadcast 'weights' along minor dimensions (TF's default is major).
                explicit_weights = weights
                if len(explicit_weights.shape) < len(modality_weights.shape):
                    explicit_weights = common_layers.expand_squeeze_to_nd(
                        weights, modality_weights.shape.ndims
                    )

                return explicit_weights * modality_weights

            # Ensure that target.modality_loss() supports "weights_fn" keyword
            # argument. If it doesn't and "weights" is specified, raise an exception.
            argument_names = inspect.getargspec(loss).args
            if 'weights_fn' not in argument_names:
                raise ValueError(
                    "Explicit 'weights' given but default loss for modality doesn't "
                    "support 'weights_fn' keyword argument: %s.loss(%s)."
                    % (modality, ', '.join(argument_names))
                )

            loss_num, loss_den = loss(
                logits,
                feature,
                self._hparams,
                vocab_size,
                weights_fn = weights_fn,
            )

        loss_num *= self._problem_hparams.loss_multiplier

        if hasattr(self.hparams, 'problem') and hasattr(
            self.hparams.problem, 'task_list'
        ):
            if weights is not None:
                raise NotImplementedError(
                    'weights not yet implemented in multitask setting.'
                )
            loss_num, loss_den, summaries = multi_problem.aggregate_task_losses(
                self.hparams,
                self._problem_hparams,
                logits,
                feature_name,
                feature,
            )

            for key, val in summaries:
                tf.summary.scalar(key, val)

        return loss_num, loss_den
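
    # Losses are returned as (numerator, denominator) pairs; the scalar loss
    # is numerator / denominator, e.g. a weighted cross-entropy sum divided by
    # the total target weight (see the reduction in `loss` below).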

    def loss(self, logits, features):
        if isinstance(logits, dict):
            losses = {}
            for k, v in six.iteritems(logits):
                losses[k] = self._loss_single(
                    v, k, features[k], weights = features.get(k + '_mask')
                )

                n, d = losses[k]
                if common_layers.should_generate_summaries():
                    tf.summary.scalar(k + '_loss', n / d)
                    tf.summary.scalar(k + '_loss_num', n)
                    tf.summary.scalar(k + '_loss_den', d)
                    if getattr(
                        self.hparams, 'visualize_logits_histogram', False
                    ):
                        hist = tf.summary.histogram
                        hist(
                            k + '_predict', tf.argmax(tf.squeeze(v), axis = -1)
                        )
                        hist(k + '_targets', features[k])

            return tf.add_n([n / d for n, d in losses.values()])
        else:
            return self._loss_single(
                logits,
                'targets',
                features['targets'],
                weights = features.get('targets_mask'),
            )

    def optimize(
        self, loss, num_async_replicas = 1, use_tpu = False, variables = None
    ):
        """Return a training op minimizing loss."""
        lr = learning_rate.learning_rate_schedule(self.hparams)
        if num_async_replicas > 1:
            log_info(
                'Dividing learning rate by num_async_replicas: %d',
                num_async_replicas,
            )
        lr /= math.sqrt(float(num_async_replicas))
        train_op = optimize.optimize(
            loss, lr, self.hparams, use_tpu = use_tpu, variables = variables
        )
        return train_op

    def set_mode(self, mode):
        """Set hparams with the given mode."""
        log_info("Setting T2TModel mode to '%s'", mode)
        hparams = hparams_lib.copy_hparams(self._original_hparams)
        hparams.add_hparam('mode', mode)
        # When not in training mode, set all forms of dropout to zero.
        if mode != tf.estimator.ModeKeys.TRAIN:
            for key in hparams.values():
                if key.endswith('dropout') or key == 'label_smoothing':
                    log_info('Setting hparams.%s to 0.0', key)
                    setattr(hparams, key, 0.0)
        self._hparams = hparams

    def prepare_features_for_infer(self, features):
        """Called before inference to allow adding infer-specific features."""
        pass

    def eval_autoregressive(self, features = None, decode_length = 50):
        """Autoregressive eval.

    Quadratic time in decode_length.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.

    Returns:
      logits: `Tensor`
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
          Contains a single key "training".
    """
        results = self._slow_greedy_infer(
            features, decode_length = decode_length
        )
        return results['logits'], results['losses']

    def _fill_problem_hparams_features(self, features):
        if features is not None:
            for k, v in sorted(
                six.iteritems(
                    problem_hparams_to_features(self._problem_hparams)
                )
            ):
                if k not in features:
                    features[k] = tf.constant(v, name = k)

    def infer(
        self,
        features = None,
        decode_length = 50,
        beam_size = 1,
        top_beams = 1,
        alpha = 0.0,
        use_tpu = False,
    ):
        """A inference method.

    Quadratic time in decode_length.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.
      use_tpu: bool, whether to build the inference graph for TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }
      if slow greedy decoding is used then the dict will also contain {
          "logits": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
          "losses": a dictionary: {loss-name (string): floating point `Scalar`
      }
    """
        set_custom_getter_compose(self._custom_getter)
        with self._eager_var_store.as_default():
            # TODO(rsepassi): Make decoding work with real-valued model outputs
            # (i.e. if the target modality is RealModality).
            self.prepare_features_for_infer(features)
            if not self.has_input and beam_size > 1:
                log_warn('Beam searching for a model with no inputs.')
            if not self.has_input and self.hparams.sampling_method != 'random':
                log_warn('Non-random sampling for a model with no inputs.')
            self._fill_problem_hparams_features(features)

            if self._problem_hparams:
                target_modality = self._problem_hparams.modality['targets']
                if (
                    target_modality == modalities.ModalityType.CLASS_LABEL
                    or self._problem_hparams.get('regression_targets')
                ):
                    # No use running beam search for classification or
                    # regression; only a single step is decoded.
                    beam_size = 1
            if beam_size == 1:
                log_info('Greedy Decoding')
                results = self._greedy_infer(features, decode_length, use_tpu)
            else:
                log_info('Beam Decoding with beam size %d' % beam_size)
                results = self._beam_decode(
                    features,
                    decode_length,
                    beam_size,
                    top_beams,
                    alpha,
                    use_tpu,
                )

            return results
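
    # A usage sketch (hypothetical `features` dict; greedy decoding assumed):
    #   results = model.infer(features, decode_length=50, beam_size=1)
    #   decoded_ids = results['outputs']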

    def _beam_decode(
        self,
        features,
        decode_length,
        beam_size,
        top_beams,
        alpha,
        use_tpu = False,
    ):
        """Beam search decoding.

    Models should ideally implement a more efficient version of this function.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.
      use_tpu: A bool, whether to do beam decode on TPU.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search
    """
        return self._beam_decode_slow(
            features, decode_length, beam_size, top_beams, alpha, use_tpu
        )

    def _beam_decode_slow(
        self,
        features,
        decode_length,
        beam_size,
        top_beams,
        alpha,
        use_tpu = False,
    ):
        """Slow version of Beam search decoding.

    Quadratic time in decode_length.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.
      use_tpu: A bool, whether to do slow beam decode on TPU.

    Returns:
      samples: an integer `Tensor`. Top samples from the beam search.

    Raises:
      NotImplementedError: If use_tpu is set to true.
    """
        batch_size = common_layers.shape_list(features['inputs'])[0]

        def symbols_to_logits_fn(ids, i = None):
            """Go from ids to logits."""
            ids = tf.expand_dims(tf.expand_dims(ids, axis = 2), axis = 3)
            ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
            if 'partial_targets' in features:
                pt = features['partial_targets']
                pt_length = common_layers.shape_list(pt)[1]
                pt = tf.tile(pt, [1, beam_size])
                pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
                ids = tf.concat([pt, ids], axis = 1)

            features['targets'] = ids
            if i is not None:
                features['decode_loop_step'] = i
            self._coverage = None
            logits, _ = self(features)  # pylint: disable=not-callable
            # now self._coverage is a coverage tensor for the first datashard.
            # it has shape [batch_size] and contains floats between 0 and
            # source_length.
            if self._problem_hparams:
                modality = self._problem_hparams.modality['targets']
                top = self._hparams.top.get(
                    'targets', modalities.get_top(modality)
                )
                if getattr(top, 'pointwise', False):
                    return tf.squeeze(logits, axis = [1, 2, 3])
            # -1 due to the pad above.
            current_output_position = common_layers.shape_list(ids)[1] - 1
            logits = logits[:, current_output_position, :, :]
            return tf.squeeze(logits, axis = [1, 2])

        def _clone_examples_for_beam(old_feature, n):
            """Clone each example n times."""
            old_shape = common_layers.shape_list(old_feature)
            assert len(old_shape) >= 1

            # Expand the inputs in to the beam size.
            feature = tf.expand_dims(old_feature, 1)
            feature = tf.tile(feature, [1, n] + [1] * (len(old_shape) - 1))
            new_shape = common_layers.shape_list(feature)
            feature = tf.reshape(
                feature, [new_shape[0] * new_shape[1]] + new_shape[2:]
            )
            return feature

        initial_ids = tf.zeros([batch_size], dtype = tf.int32)

        # Clone select features multiple times to account for beam size.
        old_features = {}
        for feature_name in ['inputs', 'knowledge']:
            if feature_name not in features:
                continue
            old_features[feature_name] = features[feature_name]
            features[feature_name] = _clone_examples_for_beam(
                features[feature_name], beam_size
            )

        vocab_size = self._problem_hparams.vocab_size['targets']
        if vocab_size is not None and hasattr(self._hparams, 'vocab_divisor'):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor

        # Setting decode length to input length + decode_length
        if 'partial_targets' not in features:
            inputs = features['inputs']
            decode_length = common_layers.shape_list(inputs)[1] + features.get(
                'decode_length', decode_length
            )
        ids, scores, _ = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            stop_early = (top_beams == 1),
            use_tpu = use_tpu,
        )

        # Set features back to the unexpanded form so as not to confuse the
        # Estimator.
        features.update(old_features)

        # Return `top_beams` decodings (also remove initial id from the beam search)
        # TODO(lukaszkaiser): make it work multi-problem.
        if top_beams == 1:
            samples = ids[:, 0, 1:]
        else:
            samples = ids[:, :top_beams, 1:]

        return {'outputs': samples, 'scores': scores}

    def _greedy_infer(self, features, decode_length, use_tpu = False):
        """A greedy inference method.

    Models should ideally implement a more efficient version of this function.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      use_tpu: A bool, whether to build the inference graph for TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": None
          "logits": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
          "losses": a dictionary: {loss-name (string): floating point `Scalar`}
      }
    """
        if use_tpu:
            return self._slow_greedy_infer_tpu(features, decode_length)
        return self._slow_greedy_infer(features, decode_length)

    def _slow_greedy_infer_tpu(self, features, decode_length):
        """A slow greedy inference method on TPU.

    Quadratic time in decode_length.

    Args:
      features: An map of string to `Tensor`.
      decode_length: An integer, how many additional timesteps to decode.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": None
          "logits": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
          "losses": a dictionary: {loss-name (string): floating point `Scalar`}
      }
    """
        if not features:
            features = {}
        inputs_old = None
        if 'inputs' in features and len(features['inputs'].shape) < 4:
            inputs_old = features['inputs']
            features['inputs'] = tf.expand_dims(features['inputs'], 2)
        if not self.has_input:
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get('inputs')
            if partial_targets is None:
                partial_targets = features['targets']
            features['partial_targets'] = tf.to_int64(partial_targets)
        # Save the targets in a var and reassign it after the tf.while loop to
        # avoid the targets being in a 'while' frame. This ensures that the
        # targets, when used in metric functions, stay in the same frame as
        # other vars.
        targets_old = features.get('targets', None)

        target_modality = self._problem_hparams.modality['targets']

        def infer_step(i, recent_output, recent_logits, unused_loss):
            """Inference step."""
            if not tf.executing_eagerly():
                recent_output.set_shape([None, None, None, 1])
            padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
            features['targets'] = padded
            # This is inefficient in that it generates samples at all timesteps,
            # not just the last one, except if target_modality is pointwise.
            features['decode_loop_step'] = i
            samples, logits, losses = self.sample(features)
            # Concatenate the already-generated recent_output with the last
            # timestep of the newly-generated samples.
            top = self._hparams.top.get(
                'targets', modalities.get_top(target_modality)
            )
            if getattr(top, 'pointwise', False):
                cur_sample = samples[:, -1, :, :]
            else:
                cur_sample = samples[:, i, :, :]
            samples = tf.transpose(recent_output, perm = [1, 0, 2, 3])
            samples = inplace_ops.alias_inplace_update(
                samples, i, tf.to_int64(cur_sample)
            )
            samples = tf.transpose(samples, perm = [1, 0, 2, 3])
            if not tf.executing_eagerly():
                samples.set_shape([None, None, None, 1])

            # Assuming we have one shard for logits.
            recent_logits = tf.transpose(recent_logits, perm = [1, 0, 2, 3, 4])
            recent_logits = inplace_ops.alias_inplace_update(
                recent_logits, i, tf.squeeze(logits[:, -1:], axis = 1)
            )
            logits = tf.transpose(recent_logits, perm = [1, 0, 2, 3, 4])
            loss = sum([l for l in losses.values() if l is not None])
            return i + 1, samples, logits, loss

        # Create an initial output tensor. This will be passed
        # to the infer_step, which adds one timestep at every iteration.
        if 'partial_targets' in features:
            initial_output = tf.to_int64(features['partial_targets'])
            while len(initial_output.get_shape().as_list()) < 4:
                initial_output = tf.expand_dims(initial_output, 2)
            batch_size = common_layers.shape_list(initial_output)[0]
        else:
            batch_size = common_layers.shape_list(features['inputs'])[0]
            initial_output = tf.zeros((batch_size, 0, 1, 1), dtype = tf.int64)
        # Hack: foldl complains when the output shape is less specified than the
        # input shape, so we confuse it about the input shape.
        initial_output = tf.slice(
            initial_output,
            [0, 0, 0, 0],
            common_layers.shape_list(initial_output),
        )
        target_modality = self._problem_hparams.modality['targets']
        if (
            target_modality == modalities.ModalityType.CLASS_LABEL
            or self._problem_hparams.get('regression_targets')
        ):
            decode_length = 1
        else:
            if 'partial_targets' in features:
                prefix_length = common_layers.shape_list(
                    features['partial_targets']
                )[1]
            else:
                prefix_length = common_layers.shape_list(features['inputs'])[1]
            decode_length = prefix_length + decode_length

        # Initial values of result, logits and loss.
        result = tf.concat(
            [
                initial_output,
                tf.zeros([batch_size, decode_length, 1, 1], tf.int64),
            ],
            axis = 1,
        )
        # tensor padded to [batch_size, decode_length, 1, 1, vocab_size]
        vocab_size = self._problem_hparams.vocab_size['targets']
        if vocab_size is not None and hasattr(self._hparams, 'vocab_divisor'):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        logits = tf.zeros((batch_size, decode_length, 1, 1, vocab_size))
        if not tf.executing_eagerly():
            logits.set_shape([None, None, None, None, None])
        loss = 0.0

        def while_exit_cond(
            i, result, logits, loss
        ):  # pylint: disable=unused-argument
            """Exit the loop either if reach decode_length or EOS."""
            not_overflow = i < decode_length

            if self._problem_hparams.stop_at_eos:

                def fn_not_eos():
                    # Check if the last predicted element is an EOS
                    return tf.reduce_any(
                        tf.not_equal(
                            tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID
                        )
                    )

                not_eos = tf.cond(
                    # We only check for early stopping if there is at least 1 element (
                    # otherwise not_eos will crash).
                    tf.not_equal(i, 0),
                    fn_not_eos,
                    lambda: True,
                )

                return tf.cond(
                    tf.equal(batch_size, 1),
                    # If batch_size == 1, we check EOS for early stopping.
                    lambda: tf.logical_and(not_overflow, not_eos),
                    # Else, just wait for max length
                    lambda: not_overflow,
                )
            return not_overflow

        _, result, logits, loss = tf.while_loop(
            while_exit_cond,
            infer_step,
            [tf.constant(0), result, logits, loss],
            shape_invariants = [
                tf.TensorShape([]),
                tf.TensorShape([batch_size, decode_length, 1, 1]),
                tf.TensorShape([batch_size, decode_length, 1, 1, vocab_size]),
                tf.TensorShape([]),
            ],
            back_prop = False,
            parallel_iterations = 1,
        )
        if inputs_old is not None:  # Restore to not confuse Estimator.
            features['inputs'] = inputs_old
        # Reassign targets back to the previous value.
        if targets_old is not None:
            features['targets'] = targets_old
        losses = {'training': loss}
        if 'partial_targets' in features:
            partial_target_length = common_layers.shape_list(
                features['partial_targets']
            )[1]
            result = tf.slice(
                result, [0, partial_target_length, 0, 0], [-1, -1, -1, -1]
            )
        return {
            'outputs': result,
            'scores': None,
            'logits': logits,
            'losses': losses,
        }

    def _slow_greedy_infer(self, features, decode_length):
        """A slow greedy inference method.

    Quadratic time in decode_length.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": None
          "logits": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
          "losses": a dictionary: {loss-name (string): floating point `Scalar`}
      }
    """
        if not features:
            features = {}
        inputs_old = None
        if 'inputs' in features and len(features['inputs'].shape) < 4:
            inputs_old = features['inputs']
            features['inputs'] = tf.expand_dims(features['inputs'], 2)
        if not self.has_input:
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get('inputs')
            if partial_targets is None:
                partial_targets = features['targets']
            features['partial_targets'] = tf.to_int64(partial_targets)
        # Save the targets in a var and reassign it after the tf.while loop to
        # avoid the targets being in a 'while' frame. This ensures that the
        # targets, when used in metric functions, stay in the same frame as
        # other vars.
        targets_old = features.get('targets', None)

        target_modality = self._problem_hparams.modality['targets']

        def infer_step(recent_output, recent_logits, unused_loss):
            """Inference step."""
            if not tf.executing_eagerly():
                if self._target_modality_is_real:
                    dim = self._problem_hparams.vocab_size['targets']
                    if dim is not None and hasattr(
                        self._hparams, 'vocab_divisor'
                    ):
                        dim += (-dim) % self._hparams.vocab_divisor
                    recent_output.set_shape([None, None, None, dim])
                else:
                    recent_output.set_shape([None, None, None, 1])
            padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
            features['targets'] = padded
            # This is inefficient in that it generates samples at all timesteps,
            # not just the last one, except if target_modality is pointwise.
            samples, logits, losses = self.sample(features)
            # Concatenate the already-generated recent_output with the last
            # timestep of the newly-generated samples.
            top = self._hparams.top.get(
                'targets', modalities.get_top(target_modality)
            )
            if getattr(top, 'pointwise', False):
                cur_sample = samples[:, -1, :, :]
            else:
                cur_sample = samples[
                    :, common_layers.shape_list(recent_output)[1], :, :
                ]
            if self._target_modality_is_real:
                cur_sample = tf.expand_dims(cur_sample, axis = 1)
                samples = tf.concat([recent_output, cur_sample], axis = 1)
            else:
                cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis = 1))
                samples = tf.concat([recent_output, cur_sample], axis = 1)
                if not tf.executing_eagerly():
                    samples.set_shape([None, None, None, 1])

            # Assuming we have one shard for logits.
            logits = tf.concat([recent_logits, logits[:, -1:]], 1)
            loss = sum([l for l in losses.values() if l is not None])
            return samples, logits, loss

        # Create an initial output tensor. This will be passed
        # to the infer_step, which adds one timestep at every iteration.
        if 'partial_targets' in features:
            initial_output = tf.to_int64(features['partial_targets'])
            while len(initial_output.get_shape().as_list()) < 4:
                initial_output = tf.expand_dims(initial_output, 2)
            batch_size = common_layers.shape_list(initial_output)[0]
        else:
            batch_size = common_layers.shape_list(features['inputs'])[0]
            if self._target_modality_is_real:
                dim = self._problem_hparams.vocab_size['targets']
                if dim is not None and hasattr(self._hparams, 'vocab_divisor'):
                    dim += (-dim) % self._hparams.vocab_divisor
                initial_output = tf.zeros(
                    (batch_size, 0, 1, dim), dtype = tf.float32
                )
            else:
                initial_output = tf.zeros(
                    (batch_size, 0, 1, 1), dtype = tf.int64
                )
        # Hack: foldl complains when the output shape is less specified than the
        # input shape, so we confuse it about the input shape.
        initial_output = tf.slice(
            initial_output,
            [0, 0, 0, 0],
            common_layers.shape_list(initial_output),
        )
        target_modality = self._problem_hparams.modality['targets']
        if (
            target_modality == modalities.ModalityType.CLASS_LABEL
            or self._problem_hparams.get('regression_targets')
        ):
            decode_length = 1
        else:
            if 'partial_targets' in features:
                prefix_length = common_layers.shape_list(
                    features['partial_targets']
                )[1]
            else:
                prefix_length = common_layers.shape_list(features['inputs'])[1]
            decode_length = prefix_length + decode_length

        # Initial values of result, logits and loss.
        result = initial_output
        vocab_size = self._problem_hparams.vocab_size['targets']
        if vocab_size is not None and hasattr(self._hparams, 'vocab_divisor'):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        if self._target_modality_is_real:
            logits = tf.zeros((batch_size, 0, 1, vocab_size))
            logits_shape_inv = [None, None, None, None]
        else:
            # tensor of shape [batch_size, time, 1, 1, vocab_size]
            logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))
            logits_shape_inv = [None, None, None, None, None]
        if not tf.executing_eagerly():
            logits.set_shape(logits_shape_inv)

        loss = 0.0

        def while_exit_cond(
            result, logits, loss
        ):  # pylint: disable=unused-argument
            """Exit the loop either if reach decode_length or EOS."""
            length = common_layers.shape_list(result)[1]

            not_overflow = length < decode_length

            if self._problem_hparams.stop_at_eos:

                def fn_not_eos():
                    return tf.not_equal(  # Check if the last predicted element is an EOS.
                        tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID
                    )

                not_eos = tf.cond(
                    # Only check for early stopping if there is at least one
                    # element (otherwise not_eos would crash).
                    tf.not_equal(length, 0),
                    fn_not_eos,
                    lambda: True,
                )

                return tf.cond(
                    tf.equal(batch_size, 1),
                    # If batch_size == 1, we check EOS for early stopping.
                    lambda: tf.logical_and(not_overflow, not_eos),
                    # Else, just wait for max length
                    lambda: not_overflow,
                )
            return not_overflow

        result, logits, loss = tf.while_loop(
            while_exit_cond,
            infer_step,
            [result, logits, loss],
            shape_invariants = [
                tf.TensorShape([None, None, None, None]),
                tf.TensorShape(logits_shape_inv),
                tf.TensorShape([]),
            ],
            back_prop = False,
            parallel_iterations = 1,
        )
        if inputs_old is not None:  # Restore to not confuse Estimator.
            features['inputs'] = inputs_old
        # Reassign targets back to the previous value.
        if targets_old is not None:
            features['targets'] = targets_old
        losses = {'training': loss}
        if 'partial_targets' in features:
            partial_target_length = common_layers.shape_list(
                features['partial_targets']
            )[1]
            result = tf.slice(
                result, [0, partial_target_length, 0, 0], [-1, -1, -1, -1]
            )
        return {
            'outputs': result,
            'scores': None,
            'logits': logits,
            'losses': losses,
        }

    def sample(self, features):
        """Run the model and extract samples.

    Args:
      features: a map of string to `Tensor`.

    Returns:
      samples: an integer `Tensor`.
      logits: a list of `Tensor`s, one per datashard.
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
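
    Example use (illustrative; `model` stands in for a T2TModel instance):

      # With hparams.sampling_method == 'random', a per-call temperature may
      # be supplied via features['sampling_temp']; otherwise
      # hparams.sampling_temp is used.
      samples, logits, losses = model.sample(features)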
    """
        logits, losses = self(features)  # pylint: disable=not-callable
        if self._target_modality_is_real:
            return (
                logits,
                logits,
                losses,
            )  # Raw numbers returned from real modality.
        if self.hparams.sampling_method == 'argmax':
            samples = tf.argmax(logits, axis = -1)
        else:
            assert self.hparams.sampling_method == 'random'

            def multinomial_squeeze(logits, temperature = 1.0):
                logits_shape = common_layers.shape_list(logits)
                logits /= tf.reshape(
                    temperature, [-1] + [1] * (len(logits_shape) - 1)
                )
                reshaped_logits = tf.reshape(logits, [-1, logits_shape[-1]])
                choices = tf.multinomial(reshaped_logits, 1)
                choices = tf.reshape(choices, logits_shape[:-1])
                return choices

            temperature = features.get(
                'sampling_temp', self.hparams.sampling_temp
            )
            samples = multinomial_squeeze(logits, temperature)

        return samples, logits, losses

    def _shard_features(self, features):  # pylint: disable=missing-docstring
        sharded_features = {}
        for k, v in sorted(six.iteritems(features)):
            v = tf.convert_to_tensor(v)
            v_shape = common_layers.shape_list(v)
            if not v_shape:
                v = tf.expand_dims(v, axis = -1)
                v_shape = [1]
            if v_shape == [1]:
                v = tf.tile(v, tf.to_int32([self._num_datashards]))
            sharded_features[k] = self._data_parallelism(
                tf.identity, tf.split(v, self._num_datashards, 0)
            )
        return sharded_features
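
    # Illustrative behavior of _shard_features (assumed shapes): with
    # self._num_datashards == 2, a feature of shape [8, ...] is split into two
    # [4, ...] shards, while a scalar feature is expanded to shape [1] and
    # tiled to [2] so each shard receives a copy.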

    def _to_features_per_datashard(self, features):
        datashard_features = []
        assert len(features[list(features.keys())[0]]) == self._num_datashards
        for d in range(self._num_datashards):
            f = {k: v[d] for k, v in six.iteritems(features)}
            datashard_features.append(f)
        return datashard_features

    def _to_single_features_dict(self, datashard_features):
        assert len(datashard_features) == self._num_datashards
        features = collections.defaultdict(list)
        for feats in datashard_features:
            for k, v in six.iteritems(feats):
                features[k].append(v)
        return features

    @staticmethod
    def get_train_hooks(model_name, hook_context):
        model_cls = registry.model(model_name)
        return model_cls.train_hooks(hook_context)

    @staticmethod
    def get_eval_hooks(model_name, hook_context):
        model_cls = registry.model(model_name)
        return model_cls.eval_hooks(hook_context)

    @staticmethod
    def make_estimator_model_fn(
        model_name, hparams, decode_hparams = None, use_tpu = False
    ):
        model_cls = registry.model(model_name)

        def wrapping_model_fn(
            features, labels, mode, params = None, config = None
        ):
            return model_cls.estimator_model_fn(
                hparams,
                features,
                labels,
                mode,
                config = config,
                params = params,
                decode_hparams = decode_hparams,
                use_tpu = use_tpu,
            )

        return wrapping_model_fn
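
    # Illustrative sketch (not prescribed here): plugging the wrapped model_fn
    # into an Estimator. The model name 'transformer' and the hparams objects
    # are assumptions for the example.
    #
    #   model_fn = T2TModel.make_estimator_model_fn(
    #       'transformer', my_hparams, decode_hparams = my_decode_hparams)
    #   estimator = tf.estimator.Estimator(
    #       model_fn = model_fn, model_dir = '/tmp/t2t_model')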

    @classmethod
    def estimator_model_fn(
        cls,
        hparams,
        features,
        labels,
        mode,
        config = None,
        params = None,
        decode_hparams = None,
        use_tpu = False,
    ):
        """Model fn for Estimator.

    Args:
      hparams: HParams, model hyperparameters
      features: dict<str name, Tensor feature>
      labels: Tensor
      mode: tf.estimator.ModeKeys
      config: RunConfig, possibly with data_parallelism attribute
      params: dict, may include batch_size, use_tpu
      decode_hparams: HParams, used when mode == PREDICT.
      use_tpu: A bool, whether to build the inference graph for TPU.

    Returns:
      TPUEstimatorSpec if use_tpu else EstimatorSpec.
    """
        if mode == tf.estimator.ModeKeys.TRAIN:
            create_dummy_vars()
        hparams = hparams_lib.copy_hparams(hparams)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        reuse = tf.get_variable_scope().reuse
        model = cls(
            hparams,
            mode,
            data_parallelism = data_parallelism,
            decode_hparams = decode_hparams,
            _reuse = reuse,
        )

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            if use_tpu:
                inputs = features.get('inputs')
                if inputs is None:
                    inputs = features.get('targets')
                if inputs is None:
                    inputs = features['infer_targets']
                shape = inputs.get_shape().as_list()
                if shape[0] is None:
                    shape[0] = decode_hparams.batch_size or hparams.batch_size
                if shape[1] is None:
                    shape[1] = (
                        hparams.max_input_seq_length or hparams.max_length
                    )
                inputs.set_shape(shape)
            return model.estimator_spec_predict(features, use_tpu = use_tpu)

        # TRAIN and EVAL modes
        if (
            hparams.eval_run_autoregressive
            and mode == tf.estimator.ModeKeys.EVAL
        ):
            logits, losses_dict = model.eval_autoregressive(features)
        else:
            logits, losses_dict = model(
                features
            )  # pylint: disable=not-callable

        # Support model-generated labels by overriding features["targets"] with
        # logits["self_generated_targets"].
        if isinstance(logits, dict) and 'self_generated_targets' in logits:
            # Overwrite 'features["targets"]' and 'labels'
            # by logits["self_generated_targets"].
            tf.logging.info('Replacing targets with model-provided targets.')
            features['targets'] = labels = logits.pop('self_generated_targets')
            assert list(logits.keys()) == ['logits'], (
                # See "Returns" in the "top" method docstring for the expected
                # "logits" format when targets are generated at training time.
                "Expect only key 'logits' when there is 'self_generated_targets'. "
                'Found {}'.format(logits.keys())
            )
            # Recover the original logits tensor from the logits dict.
            logits = logits['logits']  # Can be a tf.Tensor or a dict.

        # Set known shapes
        if common_layers.is_xla_compiled():
            if isinstance(logits, dict):
                for k, v in sorted(six.iteritems(logits)):
                    if 'scalar/' in k:
                        continue

                    shape = v.get_shape().as_list()
                    if shape[0] is None:
                        shape[0] = params['batch_size']
                    if shape[1] is None:
                        shape[1] = hparams.max_length
                    v.set_shape(shape)
            else:
                shape = logits.get_shape().as_list()
                if shape[0] is None:
                    shape[0] = params['batch_size']
                if shape[1] is None:
                    shape[1] = hparams.max_length
                logits.set_shape(shape)

        assert 'training' in losses_dict

        # Attack mode
        if mode == 'attack':
            return logits

        # Summarize losses
        model._summarize_losses(losses_dict)  # pylint: disable=protected-access

        # Accumulate losses
        loss = sum(losses_dict[key] for key in sorted(losses_dict.keys()))

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            return model.estimator_spec_eval(
                features, logits, labels, loss, losses_dict
            )

        # TRAIN mode
        assert mode == tf.estimator.ModeKeys.TRAIN
        num_async_replicas = 1
        if config and not use_tpu:
            num_async_replicas = config.t2t_device_info['num_async_replicas']
        return model.estimator_spec_train(
            loss, num_async_replicas = num_async_replicas, use_tpu = use_tpu
        )

    def initialize_from_ckpt(self, ckpt_dir):
        return initialize_from_ckpt(
            ckpt_dir = ckpt_dir, hparams = self._hparams
        )

    def create_train_host_call(self):
        return create_host_call(self.hparams.model_dir)

    def create_eval_host_call(self):
        eval_dir = os.path.join(
            self.hparams.model_dir, self.hparams.get('eval_dir_name', 'eval')
        )
        return create_host_call(eval_dir)

    def estimator_spec_train(
        self, loss, num_async_replicas = 1, use_tpu = False
    ):
        """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode."""
        train_op = self.optimize(
            loss, num_async_replicas = num_async_replicas, use_tpu = use_tpu
        )

        if use_tpu:
            if self._hparams.warm_start_from:

                def scaffold_fn():
                    self.initialize_from_ckpt(self._hparams.warm_start_from)
                    return tf.train.Scaffold()

            else:
                scaffold_fn = None

            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = self.create_train_host_call()
            else:
                host_call = None

            remove_summaries()

            return contrib.tpu().TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss = loss,
                train_op = train_op,
                host_call = host_call,
                scaffold_fn = scaffold_fn,
            )
        else:
            if self._hparams.warm_start_from:
                self.initialize_from_ckpt(self._hparams.warm_start_from)

            # When loading weights from a pre-trained model, you want to be able to
            # load separate weights into the encoder and decoder.
            if self._hparams.warm_start_from_second:
                self.initialize_from_ckpt(self._hparams.warm_start_from_second)

            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN, loss = loss, train_op = train_op
            )

    def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
        """Constructs `tf.estimator.EstimatorSpec` for EVAL (evaluation) mode."""
        del losses_dict
        hparams = self.hparams

        if not hasattr(hparams, 'problem'):
            raise NotImplementedError(_no_problem_err('estimator_spec_eval'))

        problem = hparams.problem

        if common_layers.is_xla_compiled():
            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = self.create_eval_host_call()
            else:
                host_call = None

            remove_summaries()

            eval_metrics_fn = create_tpu_eval_metrics_fn(problem, hparams)

            batch_size = [
                feature.shape.as_list()[0]
                for _, feature in features.items()
                if feature.shape.ndims
            ][0]

            # Add batch dimension to all features since tpu requires the batch
            # dimension on all tensors.
            for name, feature in features.items():
                if not feature.shape.as_list():
                    # All features must have a batch dimension
                    feature = tf.tile(tf.expand_dims(feature, 0), [batch_size])
                features[name] = feature

            eval_metrics_fn_args = dict(
                logits = logits,  # possibly a dict
                labels = labels,
                features = features,  # dict
            )

            eval_metrics_fn_flat_args = _flatten_dict(eval_metrics_fn_args)
            return contrib.tpu().TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                eval_metrics = (eval_metrics_fn, eval_metrics_fn_flat_args),
                host_call = host_call,
                loss = loss,
            )
        else:
            task_list = [problem]
            if hasattr(problem, 'task_list'):
                task_list = problem.task_list

            eval_metrics_fns = metrics.create_evaluation_metrics(
                task_list, hparams
            )
            eval_metrics = {}
            for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
                if isinstance(logits, dict):
                    # The key is the middle component of metric_name:
                    # "metrics-%s/%s/%s".
                    k = metric_name.split('/')[1]
                    if k in logits:
                        eval_metrics[metric_name] = metric_fn(
                            logits[k], features, features[k]
                        )
                    else:
                        # We do not make it an error because we sometimes run models that
                        # predict only parts of the targets defined by the Problem class.
                        # For example, an autoencoder or pure-video model can run on a gym
                        # problem even if another model is also predicting other things,
                        # like actions or rewards.
                        tf.logging.warning(
                            'No key %s in logits for evaluation.' % k
                        )
                else:
                    eval_metrics[metric_name] = metric_fn(
                        logits, features, features['targets']
                    )
            if isinstance(logits, dict):
                predictions = logits
            else:
                predictions = {'predictions': logits}

            evaluation_hooks = []
            # Create a SummarySaverHook
            eval_dir = os.path.join(
                self.hparams.model_dir,
                self.hparams.get('eval_dir_name', 'eval'),
            )
            eval_summary_hook = tf.train.SummarySaverHook(
                save_steps = 1,
                output_dir = eval_dir,
                summary_op = tf.summary.merge_all(),
            )
            evaluation_hooks.append(eval_summary_hook)

            evaluation_hooks += problem.eval_hooks(features, logits, hparams)

            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                predictions = predictions,
                eval_metric_ops = eval_metrics,
                evaluation_hooks = evaluation_hooks,
                loss = loss,
            )

    def estimator_spec_predict(self, features, use_tpu = False):
        """Constructs `tf.estimator.EstimatorSpec` for PREDICT (inference) mode."""
        decode_hparams = self._decode_hparams
        top_beams = (
            decode_hparams.beam_size if decode_hparams.return_beams else 1
        )
        infer_out = self.infer(
            features,
            beam_size = decode_hparams.beam_size,
            top_beams = top_beams,
            alpha = decode_hparams.alpha,
            decode_length = decode_hparams.extra_length,
            use_tpu = use_tpu,
        )
        if isinstance(infer_out, dict):
            outputs = infer_out['outputs']
            scores = infer_out['scores']
        else:
            outputs = infer_out
            scores = None

        # Workaround for "ValueError: prediction values must be from the default
        # graph" during TPU model exporting.
        # TODO(b/130501786): remove tf.identity once default graph mismatch is fixed
        if use_tpu:
            for name, feature in features.items():
                features[name] = tf.identity(feature)

        inputs = features.get('inputs')
        if inputs is None:
            inputs = features.get('targets')

        predictions = {
            'outputs': outputs,
            'scores': scores,
            'inputs': inputs,
            'targets': features.get('infer_targets'),
        }

        # Pass through remaining features
        for name, feature in features.items():
            if name not in list(predictions.keys()) + ['infer_targets']:
                if name == 'decode_loop_step':
                    continue
                if not feature.shape.as_list():
                    # All features must have a batch dimension
                    batch_size = common_layers.shape_list(outputs)[0]
                    feature = tf.tile(tf.expand_dims(feature, 0), [batch_size])
                predictions[name] = feature

        _del_dict_non_tensors(predictions)

        export_out = {'outputs': predictions['outputs']}
        if 'scores' in predictions:
            export_out['scores'] = predictions['scores']

        if decode_hparams.get('export_extra_infer_outputs'):
            for output in decode_hparams.export_extra_infer_outputs.split(','):
                export_out[output] = infer_out[output]

        # Necessary to rejoin examples in the correct order with the Cloud ML Engine
        # batch prediction API.
        if 'batch_prediction_key' in predictions:
            export_out['batch_prediction_key'] = predictions[
                'batch_prediction_key'
            ]

        export_outputs = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.estimator.export.PredictOutput(export_out)
        }
        if use_tpu:
            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = self.create_eval_host_call()
            else:
                host_call = None

            remove_summaries()

            return contrib.tpu().TPUEstimatorSpec(
                tf.estimator.ModeKeys.PREDICT,
                predictions = predictions,
                host_call = host_call,
                export_outputs = export_outputs,
            )
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.PREDICT,
                predictions = predictions,
                export_outputs = export_outputs,
            )

    def _normalize_body_output(self, body_out):
        if isinstance(body_out, tuple):
            output, losses = body_out
            if isinstance(losses, (list, tuple)):
                losses = {
                    'extra': tf.add_n([tf.reduce_mean(l) for l in losses])
                }
            elif isinstance(losses, dict):
                pass
            else:
                losses = {'extra': tf.reduce_mean(losses)}
        else:
            output = body_out
            losses = {'extra': 0.0}

        return output, losses

    def _summarize_losses(self, losses_dict):
        """Adds `tf.summary`s to all terms in the losses dictionary."""
        if common_layers.should_generate_summaries():
            with tf.name_scope('losses'):
                for loss_name, loss_val in sorted(losses_dict.items()):
                    tf.summary.scalar(loss_name, loss_val)

    def maybe_scheduled_sampling(self, features, logits, losses):
        """Scheduled sampling.

    Performs forward inference again with "targets" feature replaced with values
    sampled from the model.

    This is the identity unless self.hparams.scheduled_sampling_prob > 0
    (the default is 0).

    **WARNING**: If hparams.scheduled_sampling_method == "parallel", this is
    not a faithful implementation of scheduled sampling. This implementation
    samples tokens for timestep t conditioned on gold tokens 1...t-1. A proper
    implementation must condition on a mix of gold and sampled tokens. Doing
    so is not efficient for models such as the Transformer.

    Args:
      features: {str: Tensor}. Features sharded along batch dimension.
      logits: Tensor. Logits for each shard of data.
      losses: 0-D Tensor or (num: 0-D Tensor, denom: 0-D Tensor) loss pair.

    Returns:
      new_logits: Tensor.
      new_losses: {str: loss} where loss is one of (i) a 0-D Tensor or
        (ii) a (num: 0-D Tensor, denom: 0-D Tensor) pair to be used in a
        weighted average.
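
    Example (illustrative, parallel method): with mixin_prob == 0.3, roughly
    30% of target positions are replaced by the model's own samples before the
    next forward pass:

      use_sampled = tf.less(
          tf.random_uniform(common_layers.shape_list(sampled_targets)), 0.3)
      mixed_targets = tf.where(use_sampled, sampled_targets, gold_targets)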
    """
        hparams = self.hparams
        problem_hparams = self._problem_hparams

        # Only do scheduled sampling if requested.
        if hparams.scheduled_sampling_prob == 0.0:
            return (logits, losses)

        # Only do scheduled sampling on language tasks.
        modality = problem_hparams.modality['targets']
        if modality not in [
            modalities.ModalityType.SYMBOL,
            modalities.ModalityType.SYMBOL_WEIGHTS_ALL,
            modalities.ModalityType.IMAGE,
        ]:
            assert hparams.scheduled_sampling_prob == 0, (
                'Scheduled sampling only applies to ModalityType.(SYMBOL, '
                'SYMBOL_WEIGHTS_ALL, IMAGE). Found {modality}. Set '
                'hparams.scheduled_sampling_prob == 0.0.'
            ).format(modality = modality)
            return (logits, losses)

        # Only do scheduled sampling when training.
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
        if not is_training:
            tf.logging.info(
                'Running in %s mode. Not using scheduled sampling.',
                hparams.mode,
            )
            return (logits, losses)

        # Pad vocabulary if vocab size must be evenly divisible by vocab_divisor.
        vocab_size = problem_hparams.vocab_size['targets']
        assert vocab_size is not None
        assert hparams.vocab_divisor == 1

        # TODO(duckworthd): Move to scheduled_sampling.py.
        def sample(x):
            """Multinomial sampling from a n-dimensional tensor."""
            samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
            reshaped_samples = tf.reshape(
                samples, common_layers.shape_list(x)[:-1]
            )
            return tf.to_int32(reshaped_samples)

        # TODO(duckworthd): Move to scheduled_sampling.py.
        def mix_gold_sampled(
            gold_targets, sampled_targets, mixin_prob, i, prev_new_targets
        ):
            """Interleave sampled and gold tokens randomly."""
            # Resample each location iid.
            should_use_sampled_targets = tf.less(
                tf.random_uniform(common_layers.shape_list(sampled_targets)),
                mixin_prob,
            )
            mixed_targets = tf.where(
                should_use_sampled_targets, sampled_targets, gold_targets
            )

            # Reuse sampled tokens for earlier timesteps.
            new_targets = tf.where(
                is_later_timestep(gold_targets, i),
                mixed_targets,
                prev_new_targets,
            )
            return new_targets

        # TODO(duckworthd): Move to scheduled_sampling.py.
        def is_later_timestep(x, pass_idx):
            """Constructs mask based on timestep."""
            assert x.shape.ndims == 4, x.shape
            x_shape = tf.shape(x)
            num_timesteps = x_shape[1]
            timesteps = tf.range(num_timesteps)
            timesteps = tf.reshape(timesteps, [1, num_timesteps, 1, 1])
            # The following is a bit untrue. For images, "num_timesteps" actually
            # represents image height, not time. We ignore that fact here.
            timesteps = tf.broadcast_to(timesteps, x_shape)
            return tf.greater_equal(timesteps, pass_idx)

        # TODO(duckworthd): Move to scheduled_sampling.py.
        def parallel_scheduled_sampling_pass(
            i, prev_new_targets, features, logits, mixin_prob
        ):
            """Generate scheduled sampling results."""
            sampled_targets = sample(logits)
            new_targets = mix_gold_sampled(
                features['targets'],
                sampled_targets,
                mixin_prob,
                i,
                prev_new_targets,
            )
            new_targets = tf.stop_gradient(
                new_targets
            )  # Treat new_targets as given.
            new_features = copy.copy(features)
            new_features['targets'] = new_targets
            with tf.variable_scope(tf.get_variable_scope(), reuse = True):
                # Compute bottom() for new_targets.
                #
                # TODO(duckworthd): Only apply bottom to 'new_targets'.
                new_transformed_features = self.bottom(new_features)

                # Compute body.
                with tf.variable_scope('body'):
                    new_body_outputs, new_losses = self._normalize_body_output(
                        self.body(new_transformed_features)
                    )
                assert 'training' not in new_losses

                # Compute top.
                new_logits = self.top(new_body_outputs, new_features)

                # Compute loss. Use original features (== labels).
                if (
                    hparams.mode != tf.estimator.ModeKeys.PREDICT
                    and hparams.mode != 'attack'
                ):
                    new_losses['training'] = self.loss(new_logits, features)
                else:
                    new_losses['training'] = 0.0

            return new_targets, new_logits, new_losses

        tf.logging.info('Using scheduled sampling.')
        tf.logging.info(
            'Warming scheduled sampling up with schedule: %s',
            hparams.scheduled_sampling_warmup_schedule,
        )
        assert (
            hparams.scheduled_sampling_prob == 1.0
        ), 'hparams.scheduled_sampling_prob must be 0 or 1.'

        if hparams.scheduled_sampling_method == 'sequential':
            tf.logging.info('Using SEQUENTIAL scheduled sampling.')
            assert hparams.scheduled_sampling_num_passes == 1, (
                'hparams.scheduled_sampling_num_passes must equal 1 if '
                'doing sequential scheduled sampling.'
            )
            return scheduled_sampling.sequential_scheduled_sampling_for_t2tmodel(
                self, features
            )
        elif hparams.scheduled_sampling_method == 'parallel':
            tf.logging.info('Using PARALLEL scheduled sampling.')
            # TODO(duckworthd): Move this block to scheduled_sampling.py.

            # Gradually increase over a warmup period. Lower numbers mean more gold
            # tokens.
            mixin_prob = scheduled_sampling.inverse_decay_mix_prob(
                hparams.scheduled_sampling_warmup_schedule,
                hparams.scheduled_sampling_gold_mixin_prob,
                hparams.scheduled_sampling_warmup_steps,
            )

            # Apply scheduled sampling over N passes. The logits from the (n-1)-th
            # pass will be mixed with gold tokens for conditioning in the n-th pass.
            assert hparams.scheduled_sampling_num_passes > 0, (
                'hparams.scheduled_sampling_num_passes must be > 0 if '
                'hparams.scheduled_sampling_prob > 0.0'
            )
            new_logits = logits
            new_losses = losses
            prev_new_targets = features['targets']
            for i in range(hparams.scheduled_sampling_num_passes):
                prev_new_targets, new_logits, new_losses = parallel_scheduled_sampling_pass(
                    i, prev_new_targets, features, new_logits, mixin_prob
                )
            return new_logits, new_losses
        else:
            raise ValueError(
                'Unknown scheduled_sampling_method = %s'
                % (hparams.scheduled_sampling_method,)
            )


def _with_timing(fn, msg, silent = False):
    def fn_with_timing(*args, **kwargs):
        start_time = time.time()
        res = fn(*args, **kwargs)
        if not silent:
            log_info(
                'Doing %s took %.3f sec.' % (msg, time.time() - start_time)
            )
        return res

    return fn_with_timing
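
# Illustrative use (hypothetical `decode_fn`): wrap a callable so every call
# logs its wall-clock duration.
#
#   timed_decode = _with_timing(decode_fn, 'decoding')
#   result = timed_decode(features)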


def create_dummy_vars():
    """Dummy vars for restore to work when not using TPU codepath."""
    var_names = set([v.name for v in tf.global_variables()])
    if 'losses_avg/problem_0/total_loss:0' in var_names:
        return
    with tf.variable_scope('losses_avg'):
        with tf.variable_scope('problem_0'):
            for var_name in ['total', 'extra', 'training']:
                tf.get_variable(
                    '%s_loss' % var_name, initializer = 100.0, trainable = False
                )
    with tf.variable_scope('train_stats'):
        tf.get_variable('problem_0_steps', initializer = 0, trainable = False)


# These metrics are implemented with py_funcs and therefore do not work on TPU.
TPU_METRIC_BLACKLIST = set(
    [
        metrics.Metrics.APPROX_BLEU,
        metrics.Metrics.ROUGE_2_F,
        metrics.Metrics.ROUGE_L_F,
        metrics.Metrics.IMAGE_SUMMARY,
    ]
)


def create_tpu_eval_metrics_fn(problem, model_hparams):
    """Create the metrics_fn that TPUEstimatorSpec expects."""

    def reduce_dimensions(predictions, labels):
        """Reduce dimensions for high-dimensional predictions and labels."""
        if len(predictions.get_shape()) > 5:
            predictions_shape = common_layers.shape_list(predictions)
            predictions = tf.reshape(
                predictions,
                [
                    predictions_shape[0],
                    predictions_shape[1],
                    -1,
                    predictions_shape[-1],
                ],
            )
            labels_shape = common_layers.shape_list(labels)
            labels = tf.reshape(labels, [labels_shape[0], labels_shape[1], -1])
        return predictions, labels
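
    # Shape sketch for reduce_dimensions (assumed video-style tensors): 6-D
    # predictions [batch, time, height, width, channels, vocab] collapse to
    # [batch, time, height*width*channels, vocab]; the matching 5-D labels
    # collapse to [batch, time, height*width*channels].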

    metric_fns = []
    eval_metrics = problem.eval_metric_fns(model_hparams)

    tm = _create_target_modality(problem.get_hparams(model_hparams).modality)
    if isinstance(tm, dict):
        for k, v in six.iteritems(tm):
            weights_fn = modalities.get_weights_fn(v)

            def make_metric_fn(metric_fn):
                """returns a metric_fn."""

                def wrapped_metric_fn(
                    logits, labels, features, weights_fn = weights_fn
                ):
                    kwargs = {}
                    args, _, keywords, _ = inspect.getargspec(metric_fn)
                    if ('features' in args) or keywords:
                        kwargs['features'] = features

                    logits, labels = reduce_dimensions(logits, labels)
                    num, den = metric_fn(
                        logits, labels, weights_fn = weights_fn, **kwargs
                    )
                    return tf.metrics.mean(num, den)

                return wrapped_metric_fn

            for metric, metric_fn in six.iteritems(eval_metrics):
                if metric in TPU_METRIC_BLACKLIST:
                    log_warn(
                        'Skipping eval metric %s in TPU_METRIC_BLACKLIST',
                        metric,
                    )
                    continue
                name = '%s/metrics-%s/%s' % (k, problem.name, metric)
                metric_fns.append((name, make_metric_fn(metric_fn)))
    else:
        weights_fn = modalities.get_weights_fn(tm)

        def make_metric_fn(metric_fn):
            """returns a metric fn."""

            def wrapped_metric_fn(logits, labels, features):
                kwargs = {}
                args, _, keywords, _ = inspect.getargspec(metric_fn)
                if ('features' in args) or keywords:
                    kwargs['features'] = features

                logits, labels = reduce_dimensions(logits, labels)
                num, den = metric_fn(
                    logits, labels, weights_fn = weights_fn, **kwargs
                )
                return tf.metrics.mean(num, den)

            return wrapped_metric_fn

        for metric, metric_fn in six.iteritems(eval_metrics):
            if metric in TPU_METRIC_BLACKLIST:
                log_warn(
                    'Skipping eval metric %s in TPU_METRIC_BLACKLIST', metric
                )
                continue
            name = 'metrics-%s/%s' % (problem.name, metric)
            metric_fns.append((name, make_metric_fn(metric_fn)))

    def all_metrics_fn(**kwargs):
        """Construct metrics dictionary."""

        original_kwargs = _unflatten_dict(
            kwargs, prefixes = ['logits', 'features']
        )
        del kwargs

        logits = original_kwargs['logits']
        labels = original_kwargs['labels']
        features = original_kwargs['features']
        del original_kwargs

        metrics_dict = {}

        for name, fn in metric_fns:
            if isinstance(logits, dict) and isinstance(labels, dict):
                for k, v in six.iteritems(logits):
                    metrics_dict['%s/%s' % (k, name)] = fn(
                        v, labels[k], features
                    )
            elif isinstance(logits, dict):
                tf.logging.warning(
                    'Logits is a dict, but labels is not; only '
                    "evaluating logits['targets'] against labels."
                )
                metrics_dict['%s/%s' % ('targets', name)] = fn(
                    logits['targets'], labels, features
                )
            else:
                metrics_dict[name] = fn(logits, labels, features)

        return metrics_dict

    return all_metrics_fn


def remove_summaries():
    """Remove summaries from the default graph."""
    g = tf.get_default_graph()
    key = tf.GraphKeys.SUMMARIES
    log_debug('Remove summaries %s' % str(g.get_collection(key)))
    del g.get_collection_ref(key)[:]
    assert not g.get_collection(key)


def create_host_call(model_dir):
    """Construct a host_call writing scalar summaries.

  Args:
    model_dir: String containing the path of the training directory.

  Returns:
    (fn, args) Pair to be called by TPUEstimator as the host_call.
  """
    graph = tf.get_default_graph()
    summaries = graph.get_collection(tf.GraphKeys.SUMMARIES)
    gs_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
    summary_kwargs = collections.OrderedDict()
    for t in summaries:
        # TODO(aidangomez): enable ImageSummary support when we have a faster method
        # see @shibow's comment in cl/202344570
        if t.op.type not in ['ScalarSummary']:
            tf.logging.warn(
                'Ignoring unsupported tf.Summary type %s' % t.op.type
            )
            continue

        name = t.op.name
        tensor = t.op.inputs[1]
        if t.op.type == 'ScalarSummary':
            assert tensor.shape.is_compatible_with([])
            if tensor.dtype == tf.int64:
                tensor = tf.to_int32(tensor)
            summary_kwargs['ScalarSummary' + name] = tf.reshape(tensor, [1])
        elif t.op.type == 'ImageSummary':
            # Note: this branch is unreachable until 'ImageSummary' is added
            # to the supported types in the filter above.
            # TODO(aidangomez): as we move to support more types, update
            # common_layers.tpu_safe_image_summary
            if tensor.dtype != tf.float32:
                tf.logging.warn(
                    'Currently T2T on TPU only supports ImageSummary of '
                    'tf.float32-type Tensors. Skipping Tensor '
                    '%s with dtype %s...' % (tensor.name, tensor.dtype)
                )
                continue
            # tensor = tf.to_float(tensor)
            summary_kwargs['ImageSummary' + name] = tensor
    # When no supported summaries are found, don't create a host_call.
    # Otherwise, the TPU outfeed queue would enqueue global_step while
    # host_call doesn't dequeue it, eventually causing a hang.
    if not summary_kwargs:
        return None
    summary_kwargs['global_step'] = gs_t
    log_info('summary_kwargs %s' % str(summary_kwargs))

    def host_call_fn(**kwargs):
        """Training host call. Creates summaries for training metrics.

    Args:
      **kwargs: Dict of {str: Tensor}, with `Tensor` of shape `[batch]`. Must
        contain key "global_step" with value of current global_step Tensor.

    Returns:
      List of summary ops to run on the CPU host.
    """
        gs = tf.to_int64(kwargs.pop('global_step')[0])
        with contrib.summary().create_file_writer(model_dir).as_default():
            with contrib.summary().always_record_summaries():
                # We need to use tf.contrib.summary in order to feed the `step`.
                for name, value in sorted(six.iteritems(kwargs)):
                    if name.startswith('ScalarSummary'):
                        name = name[len('ScalarSummary') :]
                        contrib.summary().scalar(
                            name, tf.reduce_mean(tf.to_float(value)), step = gs
                        )
                    elif name.startswith('ImageSummary'):
                        name = name[len('ImageSummary') :]
                        contrib.summary().image(name, value, step = gs)

                return contrib.summary().all_summary_ops()

    return (host_call_fn, summary_kwargs)
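
# Illustrative sketch (assumption, not prescribed by this module): the
# returned (fn, args) pair is passed straight through to TPUEstimatorSpec, and
# a None return (no supported summaries) is also accepted there.
#
#   host_call = create_host_call('/tmp/model_dir')
#   spec = contrib.tpu().TPUEstimatorSpec(
#       tf.estimator.ModeKeys.TRAIN, loss = loss, train_op = train_op,
#       host_call = host_call)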


def _del_dict_non_tensors(d):
    for k in list(d.keys()):
        if not isinstance(d[k], tf.Tensor):
            del d[k]


class DummyVariableStore(object):
    @contextlib.contextmanager
    def as_default(self):
        yield


def create_eager_var_store():
    if tf.executing_eagerly():
        return variable_scope.EagerVariableStore()
    else:
        return DummyVariableStore()


def average_sharded_losses(sharded_losses):
    """Average losses across datashards.

  Args:
    sharded_losses: list<dict<str loss_name, Tensor loss>>. The loss
      can be a single Tensor or a 2-tuple (numerator and denominator).

  Returns:
    losses: dict<str loss_name, Tensor avg_loss>
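
  Example (illustrative): two shards reporting (numerator, denominator) pairs
  reduce to a single weighted mean:

    sharded_losses = [{'training': (3.0, 2.0)}, {'training': (1.0, 2.0)}]
    # mean 'training' loss = (3.0 + 1.0) / max(1.0, 2.0 + 2.0) = 1.0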
  """
    losses = {}
    for loss_name in sorted(sharded_losses[0]):
        all_shards = [
            shard_losses[loss_name] for shard_losses in sharded_losses
        ]
        if isinstance(all_shards[0], tuple):
            sharded_num, sharded_den = zip(*all_shards)
            mean_loss = tf.add_n(sharded_num) / tf.maximum(
                tf.cast(1.0, sharded_den[0].dtype), tf.add_n(sharded_den)
            )
        else:
            mean_loss = tf.reduce_mean(all_shards)

        losses[loss_name] = mean_loss
    return losses


def summarize_features(features, num_shards = 1):
    """Generate summaries for features."""
    if not common_layers.should_generate_summaries():
        return

    with tf.name_scope('input_stats'):
        for (k, v) in sorted(six.iteritems(features)):
            if (
                isinstance(v, tf.Tensor)
                and (v.get_shape().ndims > 1)
                and (v.dtype != tf.string)
            ):
                tf.summary.scalar('%s_batch' % k, tf.shape(v)[0] // num_shards)
                tf.summary.scalar('%s_length' % k, tf.shape(v)[1])
                nonpadding = tf.to_float(tf.not_equal(v, 0))
                nonpadding_tokens = tf.reduce_sum(nonpadding)
                tf.summary.scalar('%s_nonpadding_tokens' % k, nonpadding_tokens)
                tf.summary.scalar(
                    '%s_nonpadding_fraction' % k, tf.reduce_mean(nonpadding)
                )


_already_logged = set()


def _eager_log(level, *args):
    if tf.executing_eagerly() and args in _already_logged:
        return
    _already_logged.add(args)
    getattr(tf.logging, level)(*args)


def log_debug(*args):
    _eager_log('debug', *args)


def log_info(*args):
    _eager_log('info', *args)


def log_warn(*args):
    _eager_log('warn', *args)


def _compose_custom_getters(getter_a, getter_b):
    """Compose two custom getters.

  Example use:
  tf.get_variable_scope().set_custom_getter(
    _compose_custom_getters(tf.get_variable_scope().custom_getter, new_getter))

  This composes getters in the same way as creating a new variable scope with
  the new_getter, but it does not actually create a new variable scope.

  Args:
    getter_a: a custom getter - generally from the existing variable scope.
    getter_b: a custom getter

  Returns:
    a custom getter
  """
    if not getter_a:
        return getter_b
    if not getter_b:
        return getter_a

    def getter_fn(getter, *args, **kwargs):
        return getter_b(functools.partial(getter_a, getter), *args, **kwargs)

    return getter_fn


def set_custom_getter_compose(custom_getter):
    """Set a custom getter in the current variable scope.

  Do not overwrite the existing custom getter - rather compose with it.

  Args:
    custom_getter: a custom getter.
  """
    tf.get_variable_scope().set_custom_getter(
        _compose_custom_getters(
            tf.get_variable_scope().custom_getter, custom_getter
        )
    )


def _create_target_modality(modality_dict):
    # TODO(trandustin): We require this in order to apply methods utilized
    # differently for modalities which are "targets"
    # (e.g., modality.target_bottom). In the future, remove need for this
    # behavior.
    return {
        k: v
        for k, v in six.iteritems(modality_dict)
        if 'target' in k
        and k != 'targets_segmentation'
        and k != 'targets_position'
    }
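
# Illustrative example (hypothetical modality dict): given
#   {'inputs': SYMBOL, 'targets': SYMBOL, 'targets_position': SYMBOL}
# only {'targets': SYMBOL} is returned; 'targets_position' and
# 'targets_segmentation' are excluded even though they contain "target".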


def initialize_from_ckpt(ckpt_dir, hparams):
    """Initialize variables from given directory."""
    model_dir = hparams.get('model_dir', None)
    already_has_ckpt = (
        model_dir and tf.train.latest_checkpoint(model_dir) is not None
    )
    if already_has_ckpt:
        return

    tf.logging.info('Checkpoint dir: %s', ckpt_dir)
    reader = contrib.framework().load_checkpoint(ckpt_dir)
    variable_map = {}
    for var in contrib.framework().get_trainable_variables():
        var_name = var.name.split(':')[0]
        if reader.has_tensor(var_name):
            tf.logging.info('Loading variable from checkpoint: %s', var_name)
            variable_map[var_name] = var
        else:
            tf.logging.info(
                'Cannot find variable in checkpoint, skipping: %s', var_name
            )
    tf.train.init_from_checkpoint(ckpt_dir, variable_map)
